{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k91nHOvT6I11"
      },
      "source": [
        "##### Copyright 2022 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "v1_OyHN36JyC"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C5PO5cHLFDT1"
      },
      "source": [
        "# On-device Text-to-Image Search with TensorFlow Lite Searcher Library\n",
        "\n",
        "In this colab, we showcase an end to end example of how to train an image-text dual encoder model and how to perform retrieval with TFLite Searcher Library. We are going to use the [COCO 2014](https://cocodataset.org/#home) dataset, and in the end you'll be able to retrieve images using a text description.\n",
        "\n",
        "First, we need to encode the images into high-dimensional vectors. Then we index them with [Model Maker Searcher API](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/). During inference, a TFLite text embedder encodes the text query into another high-dimensional vector in the same embedding space, and invokes the [on-device ScaNN searcher](https://github.com/tensorflow/tflite-support/tree/master/tensorflow_lite_support/scann_ondevice) to retrieve similar images.\n"
      ]
    },
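    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, the search step scores every indexed image embedding against the query embedding and keeps the best matches. The NumPy sketch below shows that exact (brute-force) version with dot-product similarity over L2-normalized vectors. The random vectors are placeholders for real embeddings, and `l2_normalize` and `brute_force_search` are hypothetical helpers, not part of the Searcher API. ScaNN approximates this result with techniques such as partitioning and quantization to stay fast on device, so its neighbors can differ slightly from the exact ones."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def l2_normalize(x):\n",
        "  # Scale vectors to unit length so that dot product equals cosine similarity.\n",
        "  return x / np.linalg.norm(x, axis=-1, keepdims=True)\n",
        "\n",
        "def brute_force_search(query, database, k):\n",
        "  # Exact top-k search: score every database row against the query.\n",
        "  scores = database @ query\n",
        "  top_k = np.argsort(-scores)[:k]  # indices of the k highest scores\n",
        "  return top_k, scores[top_k]\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# Placeholder for 1000 image embeddings of size 128 (EMB_SIZE in this colab).\n",
        "image_embeddings_demo = l2_normalize(rng.normal(size=(1000, 128)))\n",
        "# A placeholder text query that lies close to image 42 in the shared space.\n",
        "query_demo = l2_normalize(image_embeddings_demo[42] + 0.02 * rng.normal(size=128))\n",
        "\n",
        "neighbors_demo, scores_demo = brute_force_search(query_demo, image_embeddings_demo, k=5)\n",
        "print(neighbors_demo, scores_demo)"
      ]
    },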
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2Tc6uMrczn4g"
      },
      "source": [
        "You can download the pre-trained searcher model packed with ScaNN index from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite_support/searcher/text_to_image_blogpost/searcher_model.tflite) and skip to [inference](#scrollTo=EeZwqEnxW5Xl). Be sure to name it `searcher_model.tflite` and upload it to colab under the current working directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8KBOd6FudlqM"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U tensorflow tensorflow-hub tensorflow-addons\n",
        "!pip install -q -U tflite-support\n",
        "!pip install -q -U tflite-model-maker\n",
        "!pip install -q -U tensorflow-text==2.10.0b2\n",
        "!sudo apt-get -qq install libportaudio2  # Needed by tflite-support"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SXYcCLchJXil"
      },
      "source": [
        "Note you might need to restart the runtime after installation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CDDQxg1RpGPZ"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import math\n",
        "import os\n",
        "import pickle\n",
        "import random\n",
        "import shutil\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "import tensorflow.compat.v1 as tf1\n",
        "from tensorflow.keras import layers\n",
        "import tensorflow_addons as tfa\n",
        "import tensorflow_hub as hub\n",
        "import tensorflow_text as text\n",
        "from tensorflow_text.python.ops import fast_sentencepiece_tokenizer as sentencepiece_tokenizer\n",
        "\n",
        "# Suppressing tf.hub warnings\n",
        "tf.get_logger().setLevel('ERROR')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6yNYV-uBHtRY"
      },
      "outputs": [],
      "source": [
        "DATASET_DIR = 'datasets'\n",
        "CAPTION_URL = 'http://images.cocodataset.org/annotations/annotations_trainval2014.zip'\n",
        "TRAIN_IMAGE_URL = 'http://images.cocodataset.org/zips/train2014.zip'\n",
        "VALID_IMAGE_URL = 'http://images.cocodataset.org/zips/val2014.zip'\n",
        "TRAIN_IMAGE_DIR = os.path.join(DATASET_DIR, 'train2014')\n",
        "VALID_IMAGE_DIR = os.path.join(DATASET_DIR, 'val2014')\n",
        "TRAIN_IMAGE_PREFIX = 'COCO_train2014_'\n",
        "VALID_IMAGE_PREFIX = 'COCO_val2014_'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "niIHzEnzJ8HR"
      },
      "outputs": [],
      "source": [
        "IMAGE_SIZE = (384, 384)\n",
        "EFFICIENT_NET_URL = 'https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/feature_vector/2'\n",
        "UNIVERSAL_SENTENCE_ENCODER_URL = 'https://tfhub.dev/google/universal-sentence-encoder-lite/2'\n",
        "\n",
        "BATCH_SIZE = 256\n",
        "NUM_EPOCHS = 10\n",
        "SEQ_LENGTH = 128\n",
        "EMB_SIZE = 128"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tPvDrjQ9FBNw"
      },
      "source": [
        "## Get COCO dataset\n",
        "\n",
        "We are not using Tensorflow Dataset to get the [coco_captions](https://www.tensorflow.org/datasets/catalog/coco_captions) dataset due to disk space concerns. The following code will download and process the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "FrQXA95HGzYN"
      },
      "outputs": [],
      "source": [
        "#@title Functions for downloading and parsing annotations.\n",
        "\n",
        "def parse_annotation_json(json_path):\n",
        "  # Assuming the json file is already downloaded.\n",
        "  with open(json_path, 'r') as f:\n",
        "    json_obj = json.load(f)\n",
        "\n",
        "  # Parsing out the following information from the annotation json: the COCO\n",
        "  # image id and their corresponding flickr post id, as well as the captions.\n",
        "  mapping = dict()\n",
        "  for caption in json_obj['annotations']:\n",
        "    image_id = caption['image_id']\n",
        "    if image_id not in mapping:\n",
        "      mapping[image_id] = [[]]\n",
        "    mapping[image_id][0].append(caption['caption'])\n",
        "  for image in json_obj['images']:\n",
        "    # The flickr url here is the CDN url. We need to split it to get the post\n",
        "    # id.\n",
        "    flickr_url = image['flickr_url']\n",
        "    url_parts = flickr_url.split('/')\n",
        "    flickr_id = url_parts[-1].split('_')[0]\n",
        "    mapping[image['id']].append(flickr_id)\n",
        "  return list(mapping.items())\n",
        "\n",
        "\n",
        "def get_train_valid_captions():\n",
        "  # Parse and cache the annotation for train and valid\n",
        "  train_pickle_path = os.path.join(DATASET_DIR, 'train_captions.pickle')\n",
        "  valid_pickle_path = os.path.join(DATASET_DIR, 'valid_captions.pickle')\n",
        "\n",
        "  if not os.path.exists(train_pickle_path) or not os.path.exists(\n",
        "      valid_pickle_path):\n",
        "    # Parse and cache the annotations if they don't exist\n",
        "    annotation_zip = tf.keras.utils.get_file(\n",
        "        'annotations.zip',\n",
        "        cache_dir=os.path.abspath('.'),\n",
        "        cache_subdir=os.path.join(DATASET_DIR, 'tmp'),\n",
        "        origin=CAPTION_URL,\n",
        "        extract=True,\n",
        "    )\n",
        "    os.remove(annotation_zip)\n",
        "    train_img_cap = parse_annotation_json(\n",
        "        os.path.join(DATASET_DIR, 'tmp', 'annotations',\n",
        "                     'captions_train2014.json'))\n",
        "    valid_img_cap = parse_annotation_json(\n",
        "        os.path.join(DATASET_DIR, 'tmp', 'annotations',\n",
        "                     'captions_val2014.json'))\n",
        "    with open(train_pickle_path, 'wb') as f:\n",
        "      pickle.dump(train_img_cap, f)\n",
        "    with open(valid_pickle_path, 'wb') as f:\n",
        "      pickle.dump(valid_img_cap, f)\n",
        "    shutil.rmtree(os.path.join(DATASET_DIR, 'tmp'))\n",
        "  else:\n",
        "    # Load the cached annotations\n",
        "    with open(train_pickle_path, 'rb') as f:\n",
        "      train_img_cap = pickle.load(f)\n",
        "    with open(valid_pickle_path, 'rb') as f:\n",
        "      valid_img_cap = pickle.load(f)\n",
        "  return train_img_cap, valid_img_cap"
      ]
    },
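    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the structure `parse_annotation_json` produces, the sketch below applies the same steps to a tiny in-memory annotation dictionary instead of a file. The field names (`annotations`, `image_id`, `caption`, `images`, `id`, `flickr_url`) are the ones the function reads, but the ids, captions, and CDN URL are made up for illustration, and `parse_annotation_dict` is a hypothetical re-implementation, not part of any library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def parse_annotation_dict(json_obj):\n",
        "  # Same logic as parse_annotation_json, but on an already-loaded dict.\n",
        "  mapping = dict()\n",
        "  for caption in json_obj['annotations']:\n",
        "    # All captions of an image are collected in the first element of its list.\n",
        "    mapping.setdefault(caption['image_id'], [[]])[0].append(caption['caption'])\n",
        "  for image in json_obj['images']:\n",
        "    # The Flickr post id is the CDN file name up to the first underscore.\n",
        "    file_name = image['flickr_url'].split('/')[-1]\n",
        "    mapping[image['id']].append(file_name.split('_')[0])\n",
        "  return list(mapping.items())\n",
        "\n",
        "# Hypothetical annotation content for a single image with two captions.\n",
        "toy_annotations = {\n",
        "    'annotations': [\n",
        "        {'image_id': 7, 'caption': 'A cat on a sofa.'},\n",
        "        {'image_id': 7, 'caption': 'A sleeping cat.'},\n",
        "    ],\n",
        "    'images': [\n",
        "        {'id': 7, 'flickr_url': 'http://farm1.staticflickr.com/1/1234567890_abcdef_z.jpg'},\n",
        "    ],\n",
        "}\n",
        "\n",
        "# Each entry has the form (coco_id, [captions, flickr_post_id]).\n",
        "print(parse_annotation_dict(toy_annotations))"
      ]
    },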
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "OTOPBL57a74w"
      },
      "outputs": [],
      "source": [
        "#@title Functions for downloading the images and create the dataset.\n",
        "\n",
        "def get_sentencepiece_tokenizer_in_tf2():\n",
        "  # The universal sentence encoder model from TFHub is in TF1 Module format. We\n",
        "  # need to directly access the asset_paths to get the sentencepiece tokenizer\n",
        "  # proto path.\n",
        "  module = hub.load(UNIVERSAL_SENTENCE_ENCODER_URL)\n",
        "  spm_path = module.asset_paths[0].asset_path.numpy()\n",
        "  with tf.io.gfile.GFile(spm_path, mode='rb') as f:\n",
        "    return sentencepiece_tokenizer.FastSentencepieceTokenizer(f.read())\n",
        "\n",
        "\n",
        "def prepare_dataset(id_image_info_list,\n",
        "                    image_file_prefix,\n",
        "                    image_dir,\n",
        "                    image_zip_url,\n",
        "                    shuffle=False):\n",
        "  # Download and unzip the dataset if it's not there already.\n",
        "  if not os.path.exists(image_dir):\n",
        "    image_zip = tf.keras.utils.get_file(\n",
        "        'image.zip',\n",
        "        cache_dir=os.path.abspath('.'),\n",
        "        cache_subdir=os.path.join(DATASET_DIR),\n",
        "        origin=image_zip_url,\n",
        "        extract=True,\n",
        "    )\n",
        "    os.remove(image_zip)\n",
        "\n",
        "  # Convert the lists into tensors so that we can index into it in the dataset\n",
        "  # transformations later.\n",
        "  coco_ids, image_info = zip(*id_image_info_list)\n",
        "  captions, flickr_ids = zip(*image_info)\n",
        "  file_names = list(\n",
        "      map(\n",
        "          lambda id: os.path.join(image_dir, '%s%012d.jpg' %\n",
        "                                  (image_file_prefix, id)), coco_ids))\n",
        "  coco_ids_tensor = tf.constant(coco_ids)\n",
        "  captions_tensor = tf.ragged.constant(captions)\n",
        "  file_names_tensor = tf.constant(file_names)\n",
        "  flickr_ids_tensor = tf.constant(flickr_ids)\n",
        "\n",
        "  # The initial dataset only contains the index. This is to make sure the\n",
        "  # dataset has a known size.\n",
        "  dataset = tf.data.Dataset.range(len(coco_ids))\n",
        "\n",
        "  sp = get_sentencepiece_tokenizer_in_tf2()\n",
        "\n",
        "  def _load_image_and_select_caption(i):\n",
        "    image_id = coco_ids_tensor[i]\n",
        "    captions = captions_tensor[i]\n",
        "    image_path = file_names_tensor[i]\n",
        "    flickr_id = flickr_ids_tensor[i]\n",
        "    image = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=3)\n",
        "\n",
        "    # Randomly select one caption from the many captions we have for each image\n",
        "    caption_idx = tf.random.uniform((1,),\n",
        "                                    minval=0,\n",
        "                                    maxval=tf.shape(captions)[0],\n",
        "                                    dtype=tf.int32)[0]\n",
        "    caption = captions[caption_idx]\n",
        "    caption = tf.sparse.from_dense(sp.tokenize(caption))\n",
        "    example = {\n",
        "        'image': image,\n",
        "        'image_id': image_id,\n",
        "        'caption': caption,\n",
        "        'flickr_id': flickr_id\n",
        "    }\n",
        "    return example\n",
        "\n",
        "  def _resize_image(example):\n",
        "    # Efficient net requires the pixels to be in range of [0, 1].\n",
        "    example['image'] = tf.image.resize(example['image'], size=IMAGE_SIZE) / 255\n",
        "    return example\n",
        "\n",
        "  dataset = (\n",
        "      # Load the images from disk and decode them into numpy arrays.\n",
        "      dataset.map(\n",
        "          _load_image_and_select_caption,\n",
        "          num_parallel_calls=tf.data.AUTOTUNE,\n",
        "          deterministic=not shuffle)\n",
        "\n",
        "      # Resizing image is slow. We put the stage into a separate map so that it\n",
        "      # could get more threads to not be the bottleneck.\n",
        "      .map(\n",
        "          _resize_image,\n",
        "          num_parallel_calls=tf.data.AUTOTUNE,\n",
        "          deterministic=not shuffle))\n",
        "\n",
        "  if shuffle:\n",
        "    dataset = dataset.shuffle(BATCH_SIZE * 10)\n",
        "\n",
        "  dataset = dataset.batch(BATCH_SIZE)\n",
        "  return dataset"
      ]
    },
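    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`prepare_dataset` builds each image path from the split prefix and the COCO image id zero-padded to 12 digits, and picks one caption per image at random. The plain-Python sketch below repeats both steps for the training image id `318556` that appears in the parsing output below. `coco_image_path` is a hypothetical helper and the three captions are placeholders."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import os\n",
        "import random\n",
        "\n",
        "def coco_image_path(image_dir, prefix, coco_id):\n",
        "  # COCO 2014 file names are the split prefix plus the id padded to 12 digits.\n",
        "  return os.path.join(image_dir, '%s%012d.jpg' % (prefix, coco_id))\n",
        "\n",
        "path_demo = coco_image_path('datasets/train2014', 'COCO_train2014_', 318556)\n",
        "print(path_demo)  # On Linux: datasets/train2014/COCO_train2014_000000318556.jpg\n",
        "\n",
        "# tf.random.uniform with an integer dtype samples from [minval, maxval), so the\n",
        "# caption index in _load_image_and_select_caption is always in range.\n",
        "# random.randrange has the same half-open behavior in plain Python.\n",
        "captions_demo = ['caption a', 'caption b', 'caption c']\n",
        "picked_demo = captions_demo[random.randrange(len(captions_demo))]\n",
        "print(picked_demo)"
      ]
    },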
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kzpigw9ozZOM"
      },
      "source": [
        "Download the datasets and preprocess them."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 16948,
          "status": "ok",
          "timestamp": 1651885239693,
          "user": {
            "displayName": "Zonglin Li",
            "userId": "11843710831668693042"
          },
          "user_tz": 240
        },
        "id": "pHbgdBfFWmtz",
        "outputId": "38b5e03b-4c19-430f-af0a-48019383540e"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from http://images.cocodataset.org/annotations/annotations_trainval2014.zip\n",
            "252878848/252872794 [==============================] - 8s 0us/step\n",
            "252887040/252872794 [==============================] - 8s 0us/step\n",
            "Train number of images: 82783\n",
            "Valid number of images: 40504\n",
            "COCO image id: 318556\n",
            "Captions: ['A very clean and well decorated empty bathroom', 'A blue and white bathroom with butterfly themed wall tiles.', 'A bathroom with a border of butterflies and blue paint on the walls above it.', 'An angled view of a beautifully decorated bathroom.', 'A clock that blends in with the wall hangs in a bathroom. ']\n",
            "Flickr post url: http://flickr.com/photo.gne?id=3378902101\n"
          ]
        }
      ],
      "source": [
        "# We parse the caption json files first.\n",
        "train_img_cap, valid_img_cap = get_train_valid_captions()\n",
        "print(f'Train number of images: {len(train_img_cap)}')\n",
        "print(f'Valid number of images: {len(valid_img_cap)}')\n",
        "\n",
        "example = train_img_cap[0]\n",
        "print(f'COCO image id: {example[0]}')\n",
        "print(f'Captions: {example[1][0]}')\n",
        "print(f'Flickr post url: http://flickr.com/photo.gne?id={example[1][1]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 775219,
          "status": "ok",
          "timestamp": 1651886014906,
          "user": {
            "displayName": "Zonglin Li",
            "userId": "11843710831668693042"
          },
          "user_tz": 240
        },
        "id": "Ke6EeKAqj1vB",
        "outputId": "4c550552-270b-435a-e8d7-a73c92da9ef9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from http://images.cocodataset.org/zips/val2014.zip\n",
            "6645014528/6645013297 [==============================] - 183s 0us/step\n",
            "6645022720/6645013297 [==============================] - 183s 0us/step\n",
            "Downloading data from http://images.cocodataset.org/zips/train2014.zip\n",
            "13510574080/13510573713 [==============================] - 412s 0us/step\n",
            "13510582272/13510573713 [==============================] - 412s 0us/step\n"
          ]
        }
      ],
      "source": [
        "# Shuffle both the train and validation sets\n",
        "random.shuffle(valid_img_cap)\n",
        "random.shuffle(train_img_cap)\n",
        "\n",
        "# We randomly sample 5000 image-caption pairs from validation set for validation\n",
        "# during training, to match the setup of\n",
        "# https://www.tensorflow.org/datasets/catalog/coco_captions. However, when\n",
        "# generating the retrieval database later on, we will use all the images in both\n",
        "# validation and training splits.\n",
        "valid_dataset = prepare_dataset(\n",
        "    valid_img_cap[:5000],\n",
        "    VALID_IMAGE_PREFIX,\n",
        "    VALID_IMAGE_DIR,\n",
        "    VALID_IMAGE_URL)\n",
        "train_dataset = prepare_dataset(\n",
        "    train_img_cap,\n",
        "    TRAIN_IMAGE_PREFIX,\n",
        "    TRAIN_IMAGE_DIR,\n",
        "    TRAIN_IMAGE_URL,\n",
        "    shuffle=True)"
      ]
    },
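    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check on the pipeline sizes: `dataset.batch()` keeps the final partial batch by default, so the number of steps per epoch is the ceiling of the image count divided by the batch size. The image counts below are the ones printed by the parsing cell above, and the result matches the `324/324` progress bar in the training log."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "batch_size_demo = 256      # same value as BATCH_SIZE\n",
        "num_train_images = 82783   # 'Train number of images' printed above\n",
        "num_valid_images = 5000    # size of the validation subset used during training\n",
        "\n",
        "# The last partial batch is kept, so the step count rounds up.\n",
        "train_steps_demo = math.ceil(num_train_images / batch_size_demo)\n",
        "valid_steps_demo = math.ceil(num_valid_images / batch_size_demo)\n",
        "print(train_steps_demo, valid_steps_demo)  # 324 20"
      ]
    },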
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g11BA6ycJAru"
      },
      "source": [
        "## Define models"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jj2aulT90vgO"
      },
      "source": [
        "The image encoder and text encoder may not output the embeddings with the same amount of dimensions. We need to project them into the same embedding space"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k6tSQPkQBfht"
      },
      "outputs": [],
      "source": [
        "def project_embeddings(embeddings, num_projection_layers, projection_dims,\n",
        "                       dropout_rate):\n",
        "\n",
        "  projected_embeddings = layers.Dense(units=projection_dims)(embeddings)\n",
        "  for _ in range(num_projection_layers):\n",
        "    x = tf.nn.relu(projected_embeddings)\n",
        "    x = layers.Dense(projection_dims)(x)\n",
        "    x = layers.Dropout(dropout_rate)(x)\n",
        "    x = layers.Add()([projected_embeddings, x])\n",
        "    projected_embeddings = layers.LayerNormalization()(x)\n",
        "\n",
        "  # Finally we L2 normalize the embeddings. In general, L2 normalized embeddings\n",
        "  # are easier to retrieve.\n",
        "  projected_embeddings = tf.math.l2_normalize(projected_embeddings, axis=1)\n",
        "  return projected_embeddings"
      ]
    },
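    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The projection head is a small residual MLP. The NumPy sketch below reproduces its inference-time forward pass for `num_projection_layers=1` (dropout is inactive at inference), to show the shapes and why every output lands on the unit sphere. The weights are random placeholders rather than trained values, the layer normalization uses its initial scale of 1, offset of 0, and the Keras default epsilon of 1e-3, and `layer_norm` and `project_embeddings_np` are hypothetical helpers, not Keras APIs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def layer_norm(x, eps=1e-3):\n",
        "  # Normalize each row to zero mean and unit variance (scale=1, offset=0).\n",
        "  mean = x.mean(axis=-1, keepdims=True)\n",
        "  var = x.var(axis=-1, keepdims=True)\n",
        "  return (x - mean) / np.sqrt(var + eps)\n",
        "\n",
        "def project_embeddings_np(embeddings, weights):\n",
        "  # weights[0] is the input projection; each later (W, b) is one residual block.\n",
        "  w0, b0 = weights[0]\n",
        "  projected = embeddings @ w0 + b0\n",
        "  for w, b in weights[1:]:\n",
        "    x = np.maximum(projected, 0.0) @ w + b   # ReLU, then Dense\n",
        "    projected = layer_norm(projected + x)    # residual Add, then LayerNorm\n",
        "  return projected / np.linalg.norm(projected, axis=1, keepdims=True)\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# Placeholder weights: 512-d input projected to 128-d, plus one residual block.\n",
        "weights_demo = [(rng.normal(size=(512, 128)) * 0.05, np.zeros(128)),\n",
        "                (rng.normal(size=(128, 128)) * 0.05, np.zeros(128))]\n",
        "out_demo = project_embeddings_np(rng.normal(size=(4, 512)), weights_demo)\n",
        "print(out_demo.shape, np.linalg.norm(out_demo, axis=1))"
      ]
    },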
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U64G7g3pq5bH"
      },
      "outputs": [],
      "source": [
        "def create_image_encoder(num_projection_layers,\n",
        "                         projection_dims,\n",
        "                         dropout_rate,\n",
        "                         trainable=False):\n",
        "  efficient_net = hub.KerasLayer(EFFICIENT_NET_URL, trainable=trainable)\n",
        "  inputs = layers.Input(shape=IMAGE_SIZE + (3,), name='image_input')\n",
        "  embeddings = efficient_net(inputs)\n",
        "  outputs = project_embeddings(embeddings, num_projection_layers,\n",
        "                               projection_dims, dropout_rate)\n",
        "  return keras.Model(inputs, outputs, name='image_encoder')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ObnLD9KM0uy3"
      },
      "source": [
        "We use [Universal Sentence Encoder](https://tfhub.dev/google/universal-sentence-encoder-lite/2), a SOTA sentence embedding model, as the text encoder base model. The TFHub lite version is a TF1 saved model. To make it work well in TF2 and later TFLite conversion, we create two models, one is the frozen universal sentence encoder, and the other is the trainable projection layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eJ0MLxdKcFb1"
      },
      "outputs": [],
      "source": [
        "def create_text_encoder():\n",
        "  encoder = hub.KerasLayer(\n",
        "      UNIVERSAL_SENTENCE_ENCODER_URL,\n",
        "      name='universal_sentence_encoder',\n",
        "      signature='default')\n",
        "  encoder.trainable = False\n",
        "  inputs = layers.Input(\n",
        "      shape=(None,), dtype=tf.int64, name='text_input', sparse=True)\n",
        "  embeddings = encoder(\n",
        "      dict(\n",
        "          values=inputs.values,\n",
        "          indices=inputs.indices,\n",
        "          dense_shape=inputs.dense_shape))\n",
        "  return keras.Model(inputs, embeddings, name='text_encoder')\n",
        "\n",
        "\n",
        "def create_text_embedder_projection(input_dim, num_projection_layers,\n",
        "                                    projection_dims, dropout_rate):\n",
        "  inputs = layers.Input(shape=(input_dim), dtype=tf.float32, name='text_input')\n",
        "  outputs = project_embeddings(inputs, num_projection_layers, projection_dims,\n",
        "                               dropout_rate)\n",
        "  return keras.Model(inputs, outputs, name='projection_layers')"
      ]
    },
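    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The text encoder receives a batched `tf.SparseTensor` split into `values`, `indices`, and `dense_shape`. When the dataset batches the per-caption sparse token ids, the result follows the coordinate (COO) layout sketched below in plain Python. The token ids are made up, and `to_coo` is a hypothetical helper; like `tf.sparse.from_dense`, it keeps only non-zero entries, and the dense shape is the batch size by the longest caption."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def to_coo(token_id_lists):\n",
        "  # Build the (indices, values, dense_shape) triple of a 2-D sparse batch.\n",
        "  indices, values = [], []\n",
        "  for row, token_ids in enumerate(token_id_lists):\n",
        "    for col, token_id in enumerate(token_ids):\n",
        "      if token_id != 0:  # tf.sparse.from_dense drops zero entries\n",
        "        indices.append([row, col])\n",
        "        values.append(token_id)\n",
        "  max_len = max(len(ids) for ids in token_id_lists)\n",
        "  return indices, values, [len(token_id_lists), max_len]\n",
        "\n",
        "# Hypothetical sentencepiece ids for two captions of different lengths.\n",
        "indices_demo, values_demo, shape_demo = to_coo([[12, 845, 9], [77, 5]])\n",
        "print(indices_demo)  # [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]\n",
        "print(values_demo)   # [12, 845, 9, 77, 5]\n",
        "print(shape_demo)    # [2, 3]"
      ]
    },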
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yHX9RYZ62ZmC"
      },
      "source": [
        "This dual encoder model is derived from this [Keras post](https://keras.io/examples/nlp/nl_image_search/)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2jVV2cHKCWIm"
      },
      "outputs": [],
      "source": [
        "class DualEncoder(keras.Model):\n",
        "\n",
        "  def __init__(self,\n",
        "               text_encoder,\n",
        "               text_encoder_projection,\n",
        "               image_encoder,\n",
        "               temperature,\n",
        "               **kwargs):\n",
        "    super(DualEncoder, self).__init__(**kwargs)\n",
        "    self.text_encoder = text_encoder\n",
        "    self.text_encoder_projection = text_encoder_projection\n",
        "    self.image_encoder = image_encoder\n",
        "\n",
        "    # Temperature controls the contrast of softmax output. In general, a low\n",
        "    # temperature increases the contrast and a high temperature decreases it.\n",
        "    self.temperature = temperature\n",
        "    self.loss_tracker = keras.metrics.Mean(name='loss')\n",
        "\n",
        "  @property\n",
        "  def metrics(self):\n",
        "    return [self.loss_tracker]\n",
        "\n",
        "  def call(self, features, training=False):\n",
        "    # If there are two GPUs present, we use one of them for image encoder and\n",
        "    # one for text encoder. If there's only one GPU then they will be trained on\n",
        "    # the same GPU.\n",
        "    with tf.device('/gpu:0'):\n",
        "      caption_embeddings = self.text_encoder(\n",
        "          features['caption'], training=False)\n",
        "      caption_embeddings = self.text_encoder_projection(\n",
        "          caption_embeddings, training=training)\n",
        "    with tf.device('/gpu:1'):\n",
        "      image_embeddings = self.image_encoder(\n",
        "          features['image'], training=training)\n",
        "    return caption_embeddings, image_embeddings\n",
        "\n",
        "  def compute_loss(self, caption_embeddings, image_embeddings):\n",
        "    # Computing the loss with dot product similarity between image and text\n",
        "    # embeddings.\n",
        "    logits = (\n",
        "        tf.matmul(caption_embeddings, image_embeddings, transpose_b=True) /\n",
        "        self.temperature)\n",
        "    images_similarity = tf.matmul(\n",
        "        image_embeddings, image_embeddings, transpose_b=True)\n",
        "    captions_similarity = tf.matmul(\n",
        "        caption_embeddings, caption_embeddings, transpose_b=True)\n",
        "\n",
        "    # The targets is the mean of the self-similarity of the captions and images.\n",
        "    # This is more lenient to the similar examples appeared in the same batch.\n",
        "    targets = keras.activations.softmax(\n",
        "        (captions_similarity + images_similarity) / (2 * self.temperature))\n",
        "    captions_loss = keras.losses.categorical_crossentropy(\n",
        "        y_true=targets, y_pred=logits, from_logits=True)\n",
        "    images_loss = keras.losses.categorical_crossentropy(\n",
        "        y_true=tf.transpose(targets),\n",
        "        y_pred=tf.transpose(logits),\n",
        "        from_logits=True)\n",
        "    return (captions_loss + images_loss) / 2\n",
        "\n",
        "  def train_step(self, features):\n",
        "    with tf.GradientTape() as tape:\n",
        "      # Forward pass\n",
        "      caption_embeddings, image_embeddings = self(features, training=True)\n",
        "      loss = self.compute_loss(caption_embeddings, image_embeddings)\n",
        "\n",
        "    # Backward pass\n",
        "    gradients = tape.gradient(loss, self.trainable_variables)\n",
        "    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))\n",
        "    self.loss_tracker.update_state(loss)\n",
        "    return {'loss': self.loss_tracker.result()}\n",
        "\n",
        "  def test_step(self, features):\n",
        "    caption_embeddings, image_embeddings = self(features, training=False)\n",
        "    loss = self.compute_loss(caption_embeddings, image_embeddings)\n",
        "    self.loss_tracker.update_state(loss)\n",
        "    return {'loss': self.loss_tracker.result()}"
      ]
    },
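    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what `compute_loss` optimizes, the NumPy sketch below recomputes the same soft-target contrastive loss outside Keras. `log_softmax` and `soft_target_loss` are hypothetical helpers, and the embeddings are random unit vectors. Two properties are easy to check: when the caption and image embeddings coincide, the targets equal the softmax of the logits, so each row's loss reduces to the entropy of its target distribution; and a lower temperature puts more target mass on each example's own pair (the diagonal)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "def log_softmax(x):\n",
        "  # Row-wise log-softmax, shifted by the row max for numerical stability.\n",
        "  z = x - x.max(axis=1, keepdims=True)\n",
        "  return z - np.log(np.exp(z).sum(axis=1, keepdims=True))\n",
        "\n",
        "def soft_target_loss(caption_emb, image_emb, temperature):\n",
        "  # Mirrors DualEncoder.compute_loss for arrays of shape (batch, dim).\n",
        "  logits = caption_emb @ image_emb.T / temperature\n",
        "  images_similarity = image_emb @ image_emb.T\n",
        "  captions_similarity = caption_emb @ caption_emb.T\n",
        "  targets = np.exp(log_softmax(\n",
        "      (captions_similarity + images_similarity) / (2 * temperature)))\n",
        "  # Categorical cross-entropy from logits, computed row by row.\n",
        "  captions_loss = -(targets * log_softmax(logits)).sum(axis=1)\n",
        "  images_loss = -(targets.T * log_softmax(logits.T)).sum(axis=1)\n",
        "  return (captions_loss + images_loss) / 2\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "# Placeholder batch of 4 unit-length embeddings of size 8.\n",
        "emb_demo = rng.normal(size=(4, 8))\n",
        "emb_demo /= np.linalg.norm(emb_demo, axis=1, keepdims=True)\n",
        "print(soft_target_loss(emb_demo, emb_demo, temperature=0.05))\n",
        "\n",
        "# Mean target weight on the matching pair at a high and a low temperature.\n",
        "for temp in (1.0, 0.05):\n",
        "  targets_demo = np.exp(log_softmax(emb_demo @ emb_demo.T / temp))\n",
        "  print(temp, targets_demo.diagonal().mean())"
      ]
    },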
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9POw0Ye4x-XR"
      },
      "source": [
        "## Train the Dual Encoder model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y9Wz75GfxN6L"
      },
      "source": [
        "Load the models from Tensorflow Hub."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JHsZiXEK_ZwO"
      },
      "outputs": [],
      "source": [
        "# The text embedder consists of two models. One is the frozen base universal\n",
        "# sentence encoder, and the other is the trainable projection layer. We are\n",
        "# doing this instead of one model to make later TFLite model conversion easier.\n",
        "text_encoder = create_text_encoder()\n",
        "projection_layers = create_text_embedder_projection(\n",
        "    input_dim=512,  # Universal sentence encoder output has 512 dimensions\n",
        "    num_projection_layers=1,\n",
        "    projection_dims=EMB_SIZE,\n",
        "    dropout_rate=0.1)\n",
        "\n",
        "image_encoder = create_image_encoder(\n",
        "    num_projection_layers=1, projection_dims=EMB_SIZE, dropout_rate=0.1)\n",
        "\n",
        "dual_encoder = DualEncoder(\n",
        "    text_encoder, projection_layers, image_encoder, temperature=0.05)\n",
        "dual_encoder.compile(\n",
        "    optimizer=tfa.optimizers.AdamW(learning_rate=0.001, weight_decay=0.001))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tj8v8wq6xUbS"
      },
      "source": [
        "Train the dual encoder model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 12338131,
          "status": "ok",
          "timestamp": 1651898372226,
          "user": {
            "displayName": "Zonglin Li",
            "userId": "11843710831668693042"
          },
          "user_tz": 240
        },
        "id": "Q1a4h5DNCaBq",
        "outputId": "62b2a90b-fcf3-4e5f-fff6-63ed1047b57a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/10\n",
            "324/324 [==============================] - 1387s 4s/step - loss: 1.8785 - val_loss: 1.5041 - lr: 0.0010\n",
            "Epoch 2/10\n",
            "324/324 [==============================] - 1345s 4s/step - loss: 1.4041 - val_loss: 1.3767 - lr: 0.0010\n",
            "Epoch 3/10\n",
            "324/324 [==============================] - 1351s 4s/step - loss: 1.3275 - val_loss: 1.3518 - lr: 0.0010\n",
            "Epoch 4/10\n",
            "324/324 [==============================] - 1364s 4s/step - loss: 1.2792 - val_loss: 1.3365 - lr: 9.0484e-04\n",
            "Epoch 5/10\n",
            "324/324 [==============================] - 1353s 4s/step - loss: 1.2511 - val_loss: 1.3124 - lr: 8.1873e-04\n",
            "Epoch 6/10\n",
            "324/324 [==============================] - 1352s 4s/step - loss: 1.2366 - val_loss: 1.2991 - lr: 7.4082e-04\n",
            "Epoch 7/10\n",
            "324/324 [==============================] - 1359s 4s/step - loss: 1.2266 - val_loss: 1.2935 - lr: 6.7032e-04\n",
            "Epoch 8/10\n",
            "324/324 [==============================] - 1354s 4s/step - loss: 1.2154 - val_loss: 1.3117 - lr: 6.0653e-04\n",
            "Epoch 9/10\n",
            "324/324 [==============================] - 1359s 4s/step - loss: 1.2220 - val_loss: 1.3212 - lr: 5.4881e-04\n",
            "Training completed. Saving image and text encoders.\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 622). These functions will not be directly callable after loading.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Models are saved.\n"
          ]
        }
      ],
      "source": [
        "# We train the first three epochs with the learning rate of 0.001 and\n",
        "# decrease it exponentially later on.\n",
        "def lr_scheduler(epoch, lr):\n",
        "  if epoch \u003c 3:\n",
        "    return lr\n",
        "  else:\n",
        "    return max(lr * tf.math.exp(-0.1), lr * 0.1)\n",
        "\n",
        "# In colab, training takes roughly 4s per step, around 24 mins per epoch\n",
        "early_stopping = tf.keras.callbacks.EarlyStopping(\n",
        "    monitor='val_loss', patience=2, restore_best_weights=True)\n",
        "history = dual_encoder.fit(\n",
        "    train_dataset,\n",
        "    epochs=NUM_EPOCHS,\n",
        "    validation_data=valid_dataset,\n",
        "    callbacks=[\n",
        "        tf.keras.callbacks.LearningRateScheduler(lr_scheduler), early_stopping\n",
        "    ],\n",
        "    max_queue_size=2,\n",
        ")\n",
        "\n",
        "# Save the models. We don't save `text_encoder` (the frozen Universal Sentence\n",
        "# Encoder) because its TF2 SavedModel has problems converting to TFLite; we\n",
        "# only save the trained projection layers that sit on top of it.\n",
        "print('Training completed. Saving image and text encoders.')\n",
        "dual_encoder.image_encoder.save('image_encoder')\n",
        "dual_encoder.text_encoder_projection.save('text_encoder_projection')\n",
        "print('Models are saved.')"
      ]
    },
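    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check, the sketch below replays the learning-rate schedule in plain Python. It assumes an initial learning rate of 0.001 (as the comment in the training cell states) and that Keras's `LearningRateScheduler` calls the schedule at the start of every epoch with the optimizer's current rate. It uses `math.exp` instead of `tf.math.exp` so that it runs without TensorFlow, and `simulate_schedule` is a hypothetical helper, not part of this notebook.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def lr_schedule(epoch, lr):\n",
        "  # Keep the rate for epochs 0-2, then decay it by exp(-0.1) every epoch.\n",
        "  if epoch < 3:\n",
        "    return lr\n",
        "  return lr * math.exp(-0.1)\n",
        "\n",
        "\n",
        "def simulate_schedule(initial_lr, num_epochs):\n",
        "  # Hypothetical helper: feed each epoch's result back in as the next\n",
        "  # epoch's input, the way the Keras callback chains the learning rate.\n",
        "  rates = []\n",
        "  lr = initial_lr\n",
        "  for epoch in range(num_epochs):\n",
        "    lr = lr_schedule(epoch, lr)\n",
        "    rates.append(lr)\n",
        "  return rates\n",
        "\n",
        "\n",
        "rates = simulate_schedule(0.001, 10)\n",
        "for epoch, lr in enumerate(rates, start=1):\n",
        "  print(f'Epoch {epoch}: lr={lr:.4e}')\n",
        "```\n",
        "\n",
        "Epochs 4 through 9 print 9.0484e-04 down to 5.4881e-04, which matches the `lr` column of the training log above: after n decay steps the rate is 0.001 * exp(-0.1 * n)."
      ]
    },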
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hcyej5mYxX39"
      },
      "source": [
        "## Create the text-to-image Searcher model using Model Maker"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bp0qBKkyu4jA"
      },
      "source": [
        "### Generate image embeddings"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dXdecbiY2NSs"
      },
      "source": [
        "Load the train and validation datasets again. This time we don't shuffle the train split, and we use the whole validation split. The order matters: the image embeddings computed below must line up one-to-one with the metadata entries. Since images aren't loaded until the datasets are iterated, creating them is cheap."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PA_X283yMUsR"
      },
      "outputs": [],
      "source": [
        "combined_valid_dataset = prepare_dataset(\n",
        "    valid_img_cap,\n",
        "    VALID_IMAGE_PREFIX,\n",
        "    VALID_IMAGE_DIR,\n",
        "    VALID_IMAGE_URL)\n",
        "deterministic_train_dataset = prepare_dataset(\n",
        "    train_img_cap,\n",
        "    TRAIN_IMAGE_PREFIX,\n",
        "    TRAIN_IMAGE_DIR,\n",
        "    TRAIN_IMAGE_URL)\n",
        "\n",
        "all_combined = deterministic_train_dataset.concatenate(combined_valid_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lK8FdFcx2siR"
      },
      "source": [
        "Create the metadata (the image file path and the Flickr post ID) for each image, in the same order as the datasets above. It will later be packed into the TFLite model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M98I3IaHRBEl"
      },
      "outputs": [],
      "source": [
        "def create_metadata(image_file_prefix, image_dir):\n",
        "\n",
        "  def _create_metadata(image_info):\n",
        "    # This is the same way we generated the image paths in the prepare_dataset\n",
        "    # function above\n",
        "    coco_id = image_info[0]\n",
        "    flickr_id = image_info[1][1]\n",
        "    return ('%s_%s' %\n",
        "            (flickr_id,\n",
        "             os.path.join(image_dir, '%s%012d.jpg' %\n",
        "                          (image_file_prefix, coco_id)))).encode('utf-8')\n",
        "\n",
        "  return _create_metadata\n",
        "\n",
        "\n",
        "# We don't store the images in the index file, as that would be too big. We only\n",
        "# store the image path and the corresponding Flickr id.\n",
        "metadata = list(\n",
        "    map(create_metadata(TRAIN_IMAGE_PREFIX, TRAIN_IMAGE_DIR), train_img_cap))\n",
        "metadata.extend(\n",
        "    map(create_metadata(VALID_IMAGE_PREFIX, VALID_IMAGE_DIR), valid_img_cap))"
      ]
    },
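    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each metadata entry is a UTF-8 byte string of the form `<flickr_id>_<image_path>`. The image path can itself contain underscores (COCO 2014 file names look like `COCO_train2014_000000000009.jpg`), so decoding has to split only on the first `_`. The sketch below round-trips one entry; `encode_metadata` and `decode_metadata` are hypothetical helpers, and the Flickr ID, directory, and prefix are placeholder values rather than data from this notebook.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "\n",
        "def encode_metadata(flickr_id, image_dir, image_prefix, coco_id):\n",
        "  # Same layout as _create_metadata above: '<flickr_id>_<image path>'.\n",
        "  image_path = os.path.join(image_dir, '%s%012d.jpg' % (image_prefix, coco_id))\n",
        "  return ('%s_%s' % (flickr_id, image_path)).encode('utf-8')\n",
        "\n",
        "\n",
        "def decode_metadata(blob):\n",
        "  # Flickr IDs are numeric, so the first '_' separates the ID from the path.\n",
        "  # Splitting only once keeps any underscores inside the path intact.\n",
        "  flickr_id, image_path = blob.decode('utf-8').split('_', 1)\n",
        "  return flickr_id, image_path\n",
        "\n",
        "\n",
        "# Placeholder values, for illustration only.\n",
        "blob = encode_metadata('12345', 'train2014', 'COCO_train2014_', 9)\n",
        "print(blob)\n",
        "print(decode_metadata(blob))\n",
        "```\n",
        "\n",
        "`split('_', 1)` gives the same result as the `split('_')` followed by `'_'.join(metadata[1:])` used in `search_image_with_text` below."
      ]
    },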
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EYUk_D2S24Fg"
      },
      "source": [
        "Generate the embeddings for all the images. We compute them directly in TensorFlow on the GPU rather than through Model Maker. Again, these will be packed into the TFLite model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 1147631,
          "status": "ok",
          "timestamp": 1651899528619,
          "user": {
            "displayName": "Zonglin Li",
            "userId": "11843710831668693042"
          },
          "user_tz": 240
        },
        "id": "Vk--b8EgQhHo",
        "outputId": "e4ea27c3-175e-43bd-d40e-a2668a8c9298"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "483/483 [==============================] - 1147s 2s/step\n",
            "Embedding matrix shape: (123287, 128)\n"
          ]
        }
      ],
      "source": [
        "# Image encoder takes one input named `image_input` so we remove other values in\n",
        "# the dataset.\n",
        "image_dataset = all_combined.map(\n",
        "    lambda example: {'image_input': example['image']})\n",
        "image_embeddings = dual_encoder.image_encoder.predict(image_dataset, verbose=1)\n",
        "print(f'Embedding matrix shape: {image_embeddings.shape}')"
      ]
    },
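    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Some back-of-the-envelope arithmetic on the embedding matrix printed above (123,287 images by 128 dimensions) helps explain the searcher settings used later. This sketch assumes float32 embeddings and, for the partition estimate, roughly equal-sized partitions, which K-means does not guarantee. `num_leaves` and `num_leaves_to_search` mirror the values in the `ScaNNOptions` cell below.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "num_images, dim = 123287, 128\n",
        "\n",
        "# Raw float32 storage for the database embeddings (4 bytes per value).\n",
        "raw_bytes = num_images * dim * 4\n",
        "print(f'Raw float32 embeddings: {raw_bytes / 1e6:.1f} MB')\n",
        "\n",
        "# Rule of thumb used for the K-means tree: num_leaves ~ sqrt(N).\n",
        "num_leaves = int(math.sqrt(num_images))\n",
        "num_leaves_to_search = 4\n",
        "\n",
        "# Assuming balanced partitions, each query scores only this many candidates.\n",
        "avg_leaf_size = num_images / num_leaves\n",
        "candidates = num_leaves_to_search * avg_leaf_size\n",
        "print(f'num_leaves={num_leaves}, ~{avg_leaf_size:.0f} images per leaf')\n",
        "print(f'~{candidates:.0f} candidates per query '\n",
        "      f'({100 * candidates / num_images:.2f}% of the database)')\n",
        "```\n",
        "\n",
        "Under these assumptions, each query scores roughly 1,400 of the 123,287 embeddings (about 1.1%), which is why searching only a few partitions is so much cheaper than scanning the whole index."
      ]
    },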
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6Dzye66Xc8vE"
      },
      "source": [
        "### Convert text embedder to TFLite"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IBef6gzm3AIQ"
      },
      "source": [
        "We need to rebuild the text encoder as a TF1 SavedModel, because the base Universal Sentence Encoder is a TF1 Hub module. The cell below combines it with the trained projection layers and writes the result to a directory on disk called `converted_model`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 10521,
          "status": "ok",
          "timestamp": 1651899539127,
          "user": {
            "displayName": "Zonglin Li",
            "userId": "11843710831668693042"
          },
          "user_tz": 240
        },
        "id": "jJV-44C0c_FK",
        "outputId": "ef8ac6ee-7d65-470d-b6de-897df5a466af"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Model saved to converted_model/\n"
          ]
        }
      ],
      "source": [
        "#@title Prepare the saved model\n",
        "!rm -rf converted_model\n",
        "\n",
        "# This creates a new TF1 SavedModel from (1) the TF Hub USE module and (2) the\n",
        "# projection layers trained and saved in TF2.\n",
        "with tf1.Graph().as_default() as g:\n",
        "  with tf1.Session() as sess:\n",
        "    # Reload the Universal Sentence Encoder from TF Hub. We can't simply save\n",
        "    # the USE in TF2 as we did for the projection layers, because that causes\n",
        "    # issues in the TFLite converter.\n",
        "    module = hub.Module(UNIVERSAL_SENTENCE_ENCODER_URL)\n",
        "    spm_path = sess.run(module(signature='spm_path'))\n",
        "    with tf1.io.gfile.GFile(spm_path, mode='rb') as f:\n",
        "      serialized_spm = f.read()\n",
        "    input_str = tf1.placeholder(dtype=tf1.string, shape=[None])\n",
        "    tokenizer = sentencepiece_tokenizer.FastSentencepieceTokenizer(\n",
        "        model=serialized_spm)\n",
        "    tokenized = tf1.sparse.from_dense(tokenizer.tokenize(input_str).to_tensor())\n",
        "    tokenized = tf1.cast(tokenized, dtype=tf1.int64)\n",
        "    encodings = module(\n",
        "        inputs=dict(\n",
        "            values=tokenized.values,\n",
        "            indices=tokenized.indices,\n",
        "            dense_shape=tokenized.dense_shape))\n",
        "\n",
        "    # Then combine it with the trained projection layers\n",
        "    projection_layers = tf1.keras.models.load_model('text_encoder_projection')\n",
        "    encodings = projection_layers(encodings)\n",
        "\n",
        "    sess.run([tf1.global_variables_initializer(), tf1.tables_initializer()])\n",
        "\n",
        "    # Save with SavedModelBuilder\n",
        "    builder = tf1.saved_model.Builder('converted_model')\n",
        "    sig_def = tf1.saved_model.predict_signature_def(\n",
        "        inputs={'input': input_str}, outputs={'output': encodings})\n",
        "    builder.add_meta_graph_and_variables(\n",
        "        sess,\n",
        "        tags=['serve'],\n",
        "        signature_def_map={\n",
        "            tf1.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def\n",
        "        },\n",
        "        clear_devices=True)\n",
        "    builder.save()\n",
        "print('Model saved to converted_model/')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XeS_H13j3KY_"
      },
      "source": [
        "Convert and save the TFLite model. At this point the model contains only the text encoder; we add the retrieval index in the following steps."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DPGs2kxbdGtK"
      },
      "outputs": [],
      "source": [
        "converter = tf.lite.TFLiteConverter.from_saved_model('converted_model')\n",
        "converter.experimental_new_converter = True\n",
        "converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]\n",
        "converter.allow_custom_ops = True\n",
        "converted_model_tflite = converter.convert()\n",
        "with open('text_embedder.tflite', 'wb') as f:\n",
        "  f.write(converted_model_tflite)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7gqnxcsDCTeq"
      },
      "source": [
        "### Create TFLite Searcher model\n",
        "\n",
        "In general, see the documentation of [`ScaNNOptions`](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/ScaNNOptions) for how to configure the searcher for your dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z6peW6vvxMnF"
      },
      "outputs": [],
      "source": [
        "import tflite_model_maker as mm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bi6bkdnNVAiB"
      },
      "outputs": [],
      "source": [
        "scann_options = mm.searcher.ScaNNOptions(\n",
        "    # We use the dot product similarity as this is how the model is trained\n",
        "    distance_measure='dot_product',\n",
        "    # Enable space partitioning with K-Means tree\n",
        "    tree=mm.searcher.Tree(\n",
        "        # How many partitions to have. A rule of thumb is the square root of the\n",
        "        # dataset size. In this case it's 351.\n",
        "        num_leaves=int(math.sqrt(len(metadata))),\n",
        "        # Searching 4 partitions seems to give reasonable results. Searching more\n",
        "        # generally returns better results, but it's more costly to run.\n",
        "        num_leaves_to_search=4),\n",
        "    # Compress each float to int8 in the embedding. See\n",
        "    # https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/searcher/ScoreAH\n",
        "    # for details\n",
        "    score_ah=mm.searcher.ScoreAH(\n",
        "        # Using 1 dimension per quantization block.\n",
        "        1,\n",
        "        # Generally 0.2 works pretty well.\n",
        "        anisotropic_quantization_threshold=0.2))\n",
        "\n",
        "data = mm.searcher.DataLoader(\n",
        "    embedder_path='text_embedder.tflite',\n",
        "    dataset=image_embeddings,\n",
        "    metadata=metadata)\n",
        "\n",
        "model = mm.searcher.Searcher.create_from_data(\n",
        "    data=data, scann_options=scann_options)\n",
        "model.export(\n",
        "    export_filename='searcher_model.tflite',\n",
        "    userinfo='',\n",
        "    export_format=mm.searcher.ExportFormat.TFLITE)"
      ]
    },
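    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, the searcher approximates exact maximum inner product search over the image embeddings: the tree limits scoring to a few partitions, and `ScoreAH` scores compressed embeddings. The NumPy sketch below shows only that exact baseline, on random unit vectors that stand in for real embeddings (they are not data from this notebook). It scores every database row against the query, keeps the top k, and reports distances as negative dot products, following the on-device ScaNN convention noted in `search_image_with_text` below. `brute_force_search` is a hypothetical helper, not a Model Maker API.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def brute_force_search(database, query, k):\n",
        "  # Exact maximum inner product search: score every database row.\n",
        "  scores = database @ query\n",
        "  # argpartition finds the k highest scores without a full sort;\n",
        "  # only those k are then sorted from best to worst.\n",
        "  top_k = np.argpartition(-scores, k - 1)[:k]\n",
        "  top_k = top_k[np.argsort(-scores[top_k])]\n",
        "  # Report negative dot products as distances (smaller is more similar).\n",
        "  return top_k, -scores[top_k]\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "database = rng.normal(size=(1000, 128)).astype(np.float32)\n",
        "database /= np.linalg.norm(database, axis=1, keepdims=True)\n",
        "# A query close to row 42 should retrieve row 42 first.\n",
        "query = database[42] + 0.01 * rng.normal(size=128).astype(np.float32)\n",
        "query /= np.linalg.norm(query)\n",
        "\n",
        "indices, distances = brute_force_search(database, query, k=6)\n",
        "print(indices)\n",
        "print(-distances)  # similarities, highest first\n",
        "```"
      ]
    },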
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EeZwqEnxW5Xl"
      },
      "source": [
        "## Run inference using Task Library"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z-gkyy7vXRS0"
      },
      "outputs": [],
      "source": [
        "from tflite_support.task import text\n",
        "from tflite_support.task import core"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZQXqwY_X3eP4"
      },
      "source": [
        "Configure the searcher to return 6 results per query. We don't enable L2 normalization of the query embeddings, because the text encoder has already normalized them. See the [source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/python/task/text/text_searcher.py) for how to configure the `TextSearcher`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nJlkXpDsW8_5"
      },
      "outputs": [],
      "source": [
        "options = text.TextSearcherOptions(\n",
        "    base_options=core.BaseOptions(\n",
        "        file_name='searcher_model.tflite'))\n",
        "\n",
        "# The searcher returns 6 results\n",
        "options.search_options.max_results = 6\n",
        "\n",
        "tflite_searcher = text.TextSearcher.create_from_options(options)"
      ]
    },
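    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Skipping query normalization is only safe if the query embeddings already have unit length. The NumPy sketch below uses random vectors as hypothetical stand-ins for a text embedding and an image embedding. It shows that L2-normalizing an already normalized vector leaves it unchanged, and that for unit vectors the dot product used by the searcher equals cosine similarity. It does not inspect the real text embedder; to check that, you would embed a few queries and confirm that their norms are close to 1.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def l2_normalize(x):\n",
        "  # Scale a vector to unit Euclidean length.\n",
        "  return x / np.linalg.norm(x)\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "query = l2_normalize(rng.normal(size=128))  # stand-in for a text embedding\n",
        "image = l2_normalize(rng.normal(size=128))  # stand-in for an image embedding\n",
        "\n",
        "# Normalizing a unit vector again is a no-op (up to float rounding).\n",
        "print(np.allclose(l2_normalize(query), query))\n",
        "\n",
        "# For unit vectors, dot product and cosine similarity coincide.\n",
        "dot = float(query @ image)\n",
        "cosine = dot / (np.linalg.norm(query) * np.linalg.norm(image))\n",
        "print(np.isclose(dot, cosine))\n",
        "```"
      ]
    },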
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ggwCRyT_kQGs"
      },
      "outputs": [],
      "source": [
        "def search_image_with_text(query_str, show_images=False):\n",
        "  neighbors = tflite_searcher.search(query_str)\n",
        "\n",
        "  for i, neighbor in enumerate(neighbors.nearest_neighbors):\n",
        "    metadata = neighbor.metadata.decode('utf-8').split('_')\n",
        "    flickr_id = metadata[0]\n",
        "    print('Flickr url for %d: http://flickr.com/photo.gne?id=%s' %\n",
        "          (i + 1, flickr_id))\n",
        "\n",
        "  if show_images:\n",
        "    plt.figure(figsize=(20, 13))\n",
        "    for i, neighbor in enumerate(neighbors.nearest_neighbors):\n",
        "      ax = plt.subplot(2, 3, i + 1)\n",
        "\n",
        "      # Using negative distance since on-device ScaNN returns negative\n",
        "      # dot-product distance.\n",
        "      ax.set_title('%d: Similarity: %.05f' % (i + 1, -neighbor.distance))\n",
        "      metadata = neighbor.metadata.decode('utf-8').split('_')\n",
        "      image_path = '_'.join(metadata[1:])\n",
        "      image = tf.image.decode_jpeg(\n",
        "          tf.io.read_file(image_path), channels=3) / 255\n",
        "      plt.imshow(image)\n",
        "      plt.axis('off')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JGAsS4mQ3dnX"
      },
      "source": [
        "We don't display the images here because of copyright restrictions. You can set `show_images=True` to display them, but only if you've downloaded the images in [this cell](#scrollTo=Ke6EeKAqj1vB\u0026line=12\u0026uniqifier=1). Please check the post links for the license of each image."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 172,
          "status": "ok",
          "timestamp": 1651934792149,
          "user": {
            "displayName": "Zonglin Li",
            "userId": "11843710831668693042"
          },
          "user_tz": 240
        },
        "id": "v7g0RmYjks9i",
        "outputId": "18a154fd-884d-4ad1-9498-416691831758"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Flickr url for 1: http://flickr.com/photo.gne?id=6388219123\n",
            "Flickr url for 2: http://flickr.com/photo.gne?id=30100145\n",
            "Flickr url for 3: http://flickr.com/photo.gne?id=3322126404\n",
            "Flickr url for 4: http://flickr.com/photo.gne?id=4945223078\n",
            "Flickr url for 5: http://flickr.com/photo.gne?id=120446248\n",
            "Flickr url for 6: http://flickr.com/photo.gne?id=4807048033\n"
          ]
        }
      ],
      "source": [
        "search_image_with_text('A man riding on a bike')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9w2hEIF768be"
      },
      "source": [
        "Congratulations on finishing this colab! As next steps, you can try deploying the model on-device (inference plus search takes around 6 ms on a Pixel 6), or you can train the model on your own dataset. In the meantime, don't forget to check out our documentation ([Model Maker](https://www.tensorflow.org/lite/guide/model_maker/), [Task Library](https://www.tensorflow.org/lite/inference_with_metadata/task_library/text_searcher/)) and the [reference app](https://github.com/tensorflow/examples/tree/master/lite/examples/text_searcher/android), which searches news articles in the [CNN_DailyMail dataset](https://www.tensorflow.org/datasets/catalog/cnn_dailymail)."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "On-device Text-to-Image Search with TensorFlow Lite Searcher Library",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
